# -*- coding: utf-8 -*-
"""
DPO on ChatGLM3-6B — reference-free（不传 ref_model）
- 默认 FP16 + LoRA（更稳更快）；需要时可 --qlora 开启 4bit
- 12GB 显存友好：512 序列、LoRA r=8、梯度累积默认 8、每步日志
"""

import os
# Speed-up: opt out of the flash kernel only. Any inherited request to disable the
# memory-efficient kernel is discarded so that SDPA implementation stays selectable.
# NOTE(review): confirm the installed PyTorch build actually reads these variables.
os.environ.setdefault("PYTORCH_SDP_DISABLE_FLASH_ATTENTION", "1")
os.environ.pop("PYTORCH_SDP_DISABLE_MEM_EFFICIENT", None)

import argparse
import json
import torch
from dataclasses import dataclass
from typing import Dict, Any

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from trl import DPOTrainer
from trl.trainer.dpo_config import DPOConfig
from peft import LoraConfig

# ----------------------
# 命令行参数
# ----------------------
def get_args():
    """Parse the command-line options of the training script.

    Options are declared in a single table and registered in order, so the
    ``--help`` listing follows the table.

    Returns:
        argparse.Namespace: parsed values for paths, optimisation
        hyper-parameters, LoRA settings, run-control flags and data-loader
        settings.
    """
    parser = argparse.ArgumentParser()

    option_specs = [
        # Paths
        ("--model_path", dict(type=str, default=r"D:\chatglm3-models\chatglm3-6b")),
        ("--data_jsonl", dict(type=str, default=r"D:\chongxinxuexi\pythonProject1\data\public_pairs.jsonl")),
        ("--output_dir", dict(type=str, default=r"D:\chongxinxuexi\pythonProject1\dpo_model\public")),
        # Optimisation
        ("--epochs", dict(type=int, default=1)),
        ("--lr", dict(type=float, default=1e-5)),
        ("--beta", dict(type=float, default=0.3)),
        ("--batch", dict(type=int, default=1)),   # per_device_train_batch_size
        ("--accum", dict(type=int, default=1)),   # suggested values: 1 / 8 / 16
        ("--max_length", dict(type=int, default=700)),
        ("--max_prompt_length", dict(type=int, default=256)),
        # LoRA
        ("--lora_r", dict(type=int, default=8)),
        ("--lora_alpha", dict(type=int, default=32)),
        ("--lora_dropout", dict(type=float, default=0.05)),
        # Run control
        ("--qlora", dict(action="store_true", help="启用 4bit QLoRA（Windows 上可能很慢）")),
        ("--save_each_epoch", dict(action="store_true")),
        ("--max_steps", dict(type=int, default=-1, help="先小步验证，如 200；<=0 则按 epochs 跑")),
        ("--seed", dict(type=int, default=42)),
        # Data loader
        ("--num_workers", dict(type=int, default=2)),
        ("--pin_memory", dict(action="store_true")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()

# ----------------------
# 实用函数
# ----------------------
def ensure_keys(example: Dict[str, Any]) -> None:
    """Check that a dataset record has every field preference training needs.

    Args:
        example: one record of the dataset, as a mapping.

    Raises:
        ValueError: if any of ``prompt``, ``chosen`` or ``rejected`` is absent.
            The message lists the missing fields and the full required set,
            both in alphabetical order.
    """
    # A sorted tuple (rather than a set) keeps the reported field order stable
    # from run to run; set iteration order over strings is hash-randomised.
    need = ("chosen", "prompt", "rejected")
    missing = [k for k in need if k not in example]
    if missing:
        raise ValueError(f"数据样本缺少字段：{missing}；需要包含 {list(need)}")

def print_step(msg: str):
    """Write a progress line tagged ``[INFO]`` to stdout and flush immediately."""
    line = "[INFO] {}".format(msg)
    print(line, flush=True)

# ----------------------
# 主流程
# ----------------------
def _configure_cuda(seed: int) -> None:
    """Seed the CUDA RNGs and apply matmul / attention-kernel settings; no-op without CUDA."""
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cuda.matmul.allow_tf32 = True  # speeds up fp32 matmuls on Ampere+
    # Older PyTorch releases do not expose this setter; skip it there.
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

    # Select scaled-dot-product-attention kernels through the module-level
    # toggles. `torch.backends.cuda.sdp_kernel` is a context manager and has no
    # `enable_*` attributes, so calling them on it never changed anything.
    cuda_backends = torch.backends.cuda
    toggles = ("enable_flash_sdp", "enable_mem_efficient_sdp", "enable_math_sdp")
    if all(hasattr(cuda_backends, name) for name in toggles):
        cuda_backends.enable_flash_sdp(False)
        cuda_backends.enable_mem_efficient_sdp(True)
        # The math kernel stays enabled as a fallback so that inputs the fused
        # kernel cannot handle do not abort with "no available kernel".
        cuda_backends.enable_math_sdp(True)
        print_step("已设置 SDP: 使用 mem_efficient 注意力（更快；不稳定再切回 math）")
    else:
        print_step("当前 PyTorch 不支持 SDP 内核开关，保持默认注意力实现")


def _build_bnb_config(use_qlora: bool):
    """Return a 4-bit BitsAndBytesConfig when QLoRA is requested, otherwise None.

    Falls back to None (plain FP16 loading) and logs the reason if the config
    cannot be created.
    """
    if not use_qlora:
        print_step("✅ 使用 FP16 + LoRA（默认更稳更快）")
        return None

    try:
        from transformers import BitsAndBytesConfig
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
    except Exception as e:  # deliberate best-effort: report and fall back to FP16
        print_step(f"⚠️ 启用 QLoRA 失败，将回退到 FP16（原因：{repr(e)}）")
        return None

    print_step("✅ QLoRA 已启用（4bit 量化）")
    return bnb_config


def main():
    """Run the whole DPO fine-tuning pipeline.

    Steps: parse arguments, seed and configure the backend, load tokenizer and
    model, build the LoRA config, load the JSONL preference data, train with
    ``DPOTrainer`` and save the model, tokenizer and trainer state to
    ``--output_dir``.

    Raises:
        RuntimeError: if the dataset file yields no records.
        ValueError: if the first record lacks ``prompt``/``chosen``/``rejected``.
    """
    args = get_args()
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    torch.manual_seed(args.seed)
    _configure_cuda(args.seed)

    # ---- BitsAndBytes config (only when --qlora is given)
    bnb_config = _build_bnb_config(args.qlora)

    # ---- Tokenizer
    print_step("Step 1: 开始加载 Tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    # Keep right-side padding.
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print_step("Step 2: Tokenizer 加载完成")

    # ---- Model
    print_step("Step 3: 开始加载模型")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        device_map={"": 0},          # place the whole model on cuda:0
    )
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    print_step("Step 4: 模型加载完成")

    # ---- LoRA
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # ---- Dataset
    print_step("Step 5: 开始加载数据集")
    dataset_all = load_dataset("json", data_files=args.data_jsonl)["train"]
    if len(dataset_all) == 0:
        raise RuntimeError("数据集为空，请检查 --data_jsonl 路径或文件内容。")

    # Only the first record is inspected for the required fields.
    ensure_keys(dataset_all[0])

    # No split: every record is used for training and no eval set is created.
    train_dataset = dataset_all

    print_step(f"Step 6: 数据加载完成，train={len(train_dataset)}（本次不创建评估集）")

    # ---- Training args
    logging_steps = 10  # log every 10 steps to keep logging I/O low
    save_strategy = "epoch" if args.save_each_epoch else "no"

    # NOTE(review): the model is loaded with torch_dtype=float16 even when bf16
    # autocast is selected here — confirm this combination is intended.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    # The step limit is a training argument: `Trainer.train()` rejects a
    # `max_steps` keyword. Any value <= 0 means "train for `epochs`" (-1).
    max_steps = args.max_steps if args.max_steps and args.max_steps > 0 else -1

    training_args = DPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch,
        gradient_accumulation_steps=args.accum,
        num_train_epochs=args.epochs,
        max_steps=max_steps,
        learning_rate=args.lr,
        beta=args.beta,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        remove_unused_columns=False,
        logging_steps=logging_steps,
        disable_tqdm=False,
        eval_strategy="no",              # no evaluation during training
        save_strategy=save_strategy,
        fp16=(not use_bf16) and (bnb_config is None),
        bf16=use_bf16,
        report_to=["tensorboard"],
        logging_dir=os.path.join(args.output_dir, "tb"),  # TensorBoard log directory
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=args.pin_memory,

        max_grad_norm=1.0,
    )

    # ---- Trainer (no explicit reference model and no eval dataset are supplied)
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )

    # ---- Train
    print_step("Step 7: 开始训练")
    trainer.train()

    # ---- Save
    print_step("Step 8: 保存模型与 Tokenizer")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    trainer.save_state()

    print_step(f"✅ 训练完成，模型保存在：{args.output_dir}")

if __name__ == "__main__":
    main()
